import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn.datasets import make_moons, load_digits
from torch.distributions.multivariate_normal import MultivariateNormal
import numpy as np
import numba as nb
class Block(nn.Module):
    """One RealNVP affine coupling block followed by a fixed orthogonal mixing matrix.

    The input is split in half: the first half parameterizes an affine
    transform (scale ``s = exp(s_wig)`` and shift ``t``) applied to the second
    half, which keeps the Jacobian triangular and makes inversion cheap.  A
    random orthogonal matrix ``Q`` (identity for the last block) mixes the
    dimensions between blocks; since ``|det Q| = 1`` it contributes nothing to
    the log-determinant.
    """

    def __init__(self, input_size, hidden_size, last):
        """
        Args:
            input_size: dimensionality of the data (assumed even for a clean split).
            hidden_size: width of the hidden layers of the s/t networks.
            last: if True, use the identity instead of a random orthogonal mix.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.split = self.input_size // 2
        # Final Tanh bounds the log-scale to (-1, 1), keeping exp() well conditioned.
        self.s_net = nn.Sequential(
            nn.Linear(self.split, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.split),
            nn.Tanh(),
        )
        self.t_net = nn.Sequential(
            nn.Linear(self.split, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.split),
        )
        if last:
            Q = torch.eye(self.input_size)
        else:
            Q, _ = torch.linalg.qr(torch.randn(size=(self.input_size, self.input_size)))
        # Register Q as a buffer (was a plain attribute): it now follows
        # .to(device)/.cuda() and is persisted in state_dict.  It is fixed, not trained.
        self.register_buffer("Q", Q)

    def encode(self, x):
        """Forward pass x -> z; same transform as loss_fwd, without the log-scales."""
        z, _ = self.loss_fwd(x)
        return z

    def decode(self, x):
        """Inverse pass z -> x; exactly undoes encode (un-mix with Q.T, invert the affine)."""
        x = x @ self.Q.T
        x1 = x[:, :self.split]
        x2 = x[:, self.split:]
        s = torch.exp(self.s_net(x1))
        t = self.t_net(x1)
        x2 = (x2 - t) / s
        return torch.cat((x1, x2), dim=1)

    def loss_fwd(self, x):
        """Forward pass that also returns the log-scales needed for the NLL.

        Returns:
            (z, s_wig): transformed batch and the per-dimension log-scales; the
            row-sum of s_wig is this block's log|det Jacobian| (Q adds nothing).
        """
        x1 = x[:, :self.split]
        x2 = x[:, self.split:]
        s_wig = self.s_net(x1)
        s = torch.exp(s_wig)
        t = self.t_net(x1)
        x2 = s * x2 + t
        return torch.cat((x1, x2), dim=1) @ self.Q, s_wig
class RealNVP(nn.Module):
    """Stack of affine coupling Blocks forming an invertible map to a standard normal."""

    def __init__(self, input_size, hidden_size, blocks):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.blocks = blocks
        # Only the last block uses the identity instead of a random orthogonal mix.
        stages = [Block(input_size, hidden_size, idx == blocks - 1) for idx in range(blocks)]
        self.couplings = nn.Sequential(*stages)
        self.sampler = MultivariateNormal(torch.zeros(input_size), torch.eye(input_size))

    def encode(self, x):
        """Map data to latent codes by running every block forward."""
        z = x
        for stage in self.couplings:
            z = stage.encode(z)
        return z

    def decode(self, x):
        """Map latent codes back to data by inverting the blocks in reverse order."""
        out = x
        for stage in self.couplings[::-1]:
            out = stage.decode(out)
        return out

    def loss(self, x):
        """Mean negative log-likelihood under the standard-normal prior (up to a constant)."""
        log_det = torch.zeros(len(x))
        z = x
        for stage in self.couplings:
            z, s_wig = stage.loss_fwd(z)
            log_det = log_det + s_wig.sum(dim=1)
        return torch.mean((z * z).sum(dim=1) / 2 - log_det)

    def sample(self, n_samples):
        """Draw latent codes from the prior and decode them into data space."""
        with torch.no_grad():
            return self.decode(self.sampler.sample((n_samples,)))

    def print(self):
        """Debug helper: list every parameter's name, shape, and trainability."""
        for name, param in self.named_parameters():
            print(f"{name}, {param.shape}, grad: {param.requires_grad}")
def train_inn(model, X_train, n_epochs, plot_loss=False, X_eval=None, data_loader=None):
    """Train an unconditional INN by minimizing its NLL loss with Adam.

    Args:
        model: flow exposing loss(x) (e.g. RealNVP).
        X_train: full training batch, used when data_loader is None.
        n_epochs: number of full-batch steps, or passes over data_loader.
        plot_loss: if True, plot the loss curve(s) after training.
        X_eval: optional eval batch; records an eval loss each epoch.
        data_loader: optional minibatch loader; batches are flattened to
            (-1, 64) — hard-coded for 8x8 images, TODO generalize.
    """
    optimizer = torch.optim.Adam(model.parameters())
    if data_loader is None:
        if X_eval is not None:
            eval_losses = []
        losses = []
        for i in tqdm(range(n_epochs)):
            # One full-batch gradient step per epoch (zero_grad after step is
            # equivalent to zeroing before backward here).
            loss = model.loss(X_train)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
            if X_eval is not None:
                # Evaluate without gradients, then restore training mode.
                model.eval()
                with torch.no_grad():
                    eval_losses.append(model.loss(X_eval).item())
                model.train()
        if plot_loss:
            plt.figure(figsize=(12, 8))
            if X_eval is not None:
                plt.plot(range(n_epochs), losses, label="Train Loss")
                plt.plot(range(n_epochs), eval_losses, label="Eval Loss")
                plt.title("Loss during Training")
                plt.legend()
            else:
                plt.plot(range(n_epochs), losses)
                plt.title("Train Loss", size=18)
            plt.xlabel("Epoch", size=16)
            plt.ylabel("NLL", size=16)
            plt.show()
    else:
        losses = []
        for i in tqdm(range(n_epochs)):
            # Minibatch branch: accumulate the summed loss over the epoch.
            epoch_loss = 0.0
            for X, y in data_loader:
                loss = model.loss(X.reshape(-1, 64))
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                epoch_loss += loss.item()
            losses.append(epoch_loss)
        if plot_loss:
            plt.figure(figsize=(12, 8))
            plt.plot(range(n_epochs), losses)
            plt.title("Train Loss", size=18)
            plt.xlabel("Epoch", size=16)
            plt.ylabel("NLL", size=16)
            plt.show()
# --- Unconditional RealNVP on the two-moons dataset ---
N_train = 1000
n_epochs = 1000
# Train and eval sets drawn independently from the same noisy two-moons distribution.
X_train, _ = make_moons(N_train, shuffle=True, noise=0.1)
X_train = torch.tensor(X_train).to(torch.float)
X_eval, _ = make_moons(1000, shuffle=True, noise=0.1)
X_eval = torch.tensor(X_eval).to(torch.float)
model = RealNVP(2, 16, 4)
train_inn(model, X_train, n_epochs, plot_loss=True, X_eval=X_eval)
model.eval()
# Compare the learned latent codes against true standard-normal draws.
gaussians = np.random.multivariate_normal(np.zeros(2), np.eye(2), 1000)
plt.figure(figsize=(8, 8))
# NOTE(review): model.encode(X_train) is evaluated twice here; caching it once would be cheaper.
plt.scatter(model.encode(X_train).detach().numpy()[:,0], model.encode(X_train).detach().numpy()[:,1], label="Encodings", color="C0")
plt.scatter(gaussians[:,0], gaussians[:,1], label="True Standart Normal", color="C1")
plt.title("Latent Space")
plt.xlabel("$z_1$")
plt.ylabel("$z_2$")
plt.legend()
plt.show()
# Side-by-side comparison of training data and generated samples.
synths = model.sample(1000)
fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
axes[0].scatter(X_train[:,0], X_train[:,1])
axes[0].set_title("Training Data")
axes[0].set_xlabel("X")
axes[0].set_ylabel("Y")
axes[1].scatter(synths[:,0], synths[:,1])
axes[1].set_title("Synthetic Data")
axes[1].set_xlabel("X")
plt.tight_layout()
plt.show()
0%| | 0/1000 [00:00<?, ?it/s]
@nb.njit
def kernel_fct(x1, x2, h):
    """Gaussian kernel exp(-||x1 - x2||^2 / (2 h)) between two sample vectors.

    ``h`` scales the kernel width (larger h = wider kernel).  Note that
    kernel_fct(x, x, h) == 1 for any h, which the MMD diagonal correction
    relies on.  njit already implies nopython mode; passing nopython=True
    explicitly triggers a RuntimeWarning, so it is omitted.
    """
    return np.exp(-np.sum((x1 - x2) ** 2) / (2 * h))
@nb.njit(parallel=True)
def MMD(pred, truth, bandwidths=np.array([0.4, 0.8, 1.6])):
    """Unbiased squared MMD between two sample sets, averaged over kernel bandwidths.

    Uses the unbiased estimator: within-set terms drop the diagonal
    (k(x, x) = 1 for every bandwidth) and divide by M*(M-1); the cross term
    divides by M*N.  The result can be slightly negative for close
    distributions.  The default bandwidths array is never mutated, so the
    mutable-default argument is safe here.

    njit implies nopython mode, so the redundant (and warning-producing)
    nopython=True flag is dropped.
    """
    M = len(pred)
    N = len(truth)
    term1 = 0.0
    term2 = 0.0
    term3 = 0.0
    for h in bandwidths:
        t1 = 0.0
        t2 = 0.0
        t3 = 0.0
        # Only the outermost loop is parallelized; nested prange runs serially
        # anyway, so the inner loops use plain range for clarity.
        for i in nb.prange(M):
            for j in range(M):
                t1 += kernel_fct(pred[i], pred[j], h)
        t1 -= M  # remove the M diagonal terms k(x_i, x_i) = 1
        t1 /= (M * (M - 1))
        term1 += t1
        for i in nb.prange(N):
            for j in range(N):
                t2 += kernel_fct(truth[i], truth[j], h)
        t2 -= N
        t2 /= (N * (N - 1))
        term2 += t2
        for i in nb.prange(M):
            for j in range(N):
                t3 += kernel_fct(pred[i], truth[j], h)
        t3 /= (M * N)
        term3 += t3
    return (term1 + term2 - 2 * term3) / len(bandwidths)
# Smoke test: MMD between two independent samples of the same uniform
# distribution should be close to zero.
x1 = np.random.uniform(0, 1, size=(10, 4))
x2 = np.random.uniform(0, 1, size=(10, 4))
print(MMD(x1, x2))
c:\Users\gooog\anaconda3\envs\ml_env\Lib\site-packages\numba\core\decorators.py:282: RuntimeWarning: nopython is set for njit and is ignored
warnings.warn('nopython is set for njit and is ignored', RuntimeWarning)
0.03931266188353252
# --- Hyperparameter sweep on two moons: training-set size x number of epochs ---
hidden_size = 16
n_blocks = 4
MMD_vals = np.zeros((3, 3))
X_test, _ = make_moons(1000, shuffle=True, noise=0.1)
X_test = torch.tensor(X_test).to(torch.float)
X_eval, _ = make_moons(1000, shuffle=True, noise=0.1)
X_eval = torch.tensor(X_eval).to(torch.float)
models = []
for i, N_train in enumerate([100, 1000, 10000]):
    X_train, _ = make_moons(N_train, shuffle=True, noise=0.1)
    X_train = torch.tensor(X_train).to(torch.float)
    for j, n_epochs in enumerate([100, 1000, 10000]):
        print(f"N_train = {N_train}, n_epochs = {n_epochs}")
        model = RealNVP(2, hidden_size, n_blocks)
        train_inn(model, X_train, n_epochs, plot_loss=True, X_eval=X_eval)
        model.eval()
        with torch.no_grad():
            # Latent codes vs. true standard-normal samples.
            plt.figure(figsize=(8, 8))
            gaussians = np.random.multivariate_normal(np.zeros(2), np.eye(2), 1000)
            plt.scatter(model.encode(X_train).detach().numpy()[:,0], model.encode(X_train).detach().numpy()[:,1], label="Encodings", color="C0")
            plt.scatter(gaussians[:,0], gaussians[:,1], label="True Standart Normal", color="C1", marker="x", alpha=0.8)
            plt.title("Latent Space")
            plt.xlabel("$z_1$")
            plt.ylabel("$z_2$")
            plt.legend()
            plt.show()
            # Test data vs. samples; MMD scored against the held-out test set.
            fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
            synths = model.sample(1000).numpy()
            MMD_vals[i, j] = MMD(synths, X_test.numpy())
            axes[0].scatter(X_test[:,0], X_test[:,1])
            axes[0].set_title("Test Data")
            axes[0].set_xlabel("X")
            axes[0].set_ylabel("Y")
            axes[1].scatter(synths[:,0], synths[:,1])
            axes[1].set_title(f"Synthetic Data, MMD = {MMD_vals[i, j]}")
            axes[1].set_xlabel("X")
            plt.tight_layout()
            plt.show()
        print(" ")
        print(" ")
        print(" ")
        models.append(model)
# Show all 9 models' samples ordered by ascending MMD (flatten order matches
# the order models were appended in).
MMD_sorted, models_sorted = zip(*sorted(zip(MMD_vals.flatten(), models), key=lambda x: x[0]))
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
for i, model in enumerate(models_sorted):
    synths = model.sample(1000).numpy()
    axes[i//3, i%3].scatter(synths[:,0], synths[:,1])
    axes[i//3, i%3].set_title(f"MMD = {np.round(MMD_sorted[i], 4)}")
    axes[i//3, i%3].set_xlabel("X")
    axes[i//3, i%3].set_ylabel("Y")
plt.tight_layout()
plt.show()
N_train = 100, n_epochs = 100
0%| | 0/100 [00:00<?, ?it/s]
N_train = 100, n_epochs = 1000
0%| | 0/1000 [00:00<?, ?it/s]
N_train = 100, n_epochs = 10000
0%| | 0/10000 [00:00<?, ?it/s]
N_train = 1000, n_epochs = 100
0%| | 0/100 [00:00<?, ?it/s]
N_train = 1000, n_epochs = 1000
0%| | 0/1000 [00:00<?, ?it/s]
N_train = 1000, n_epochs = 10000
0%| | 0/10000 [00:00<?, ?it/s]
N_train = 10000, n_epochs = 100
0%| | 0/100 [00:00<?, ?it/s]
N_train = 10000, n_epochs = 1000
0%| | 0/1000 [00:00<?, ?it/s]
N_train = 10000, n_epochs = 10000
0%| | 0/10000 [00:00<?, ?it/s]
# Heatmap of the MMD grid (rows: training-set size, columns: epochs).
plt.figure(figsize=(8, 8))
img = plt.imshow(MMD_vals)
plt.ylabel("N_train")
plt.xlabel("n_epochs")
plt.yticks(range(3), [100, 1000, 10000])
plt.xticks(range(3), [100, 1000, 10000])
for i in range(3):
    for j in range(3):
        # NOTE(review): sign picks the text color, presumably because
        # negative (unbiased-MMD) cells render dark in the colormap —
        # confirm against the actual colormap range.
        if MMD_vals[i, j] > 0:
            plt.text(j-0.2, i+0.025, np.round(MMD_vals[i, j], 4), color="black")
        else:
            plt.text(j-0.2, i+0.025, np.round(MMD_vals[i, j], 4), color="white")
plt.colorbar(img, fraction=0.0457, pad=0.04)
plt.show()
Analyzing the performance for various hyperparameters, we see that generally larger training sets tend to increase performance. For the training time, we observe that 10000 epochs yields quite a significant improvement over 1000 epochs for the large training set, but results in worse performance for the smallest dataset. This is likely due to overfitting. The best model is the one trained on 10000 datapoints for 10000 epochs.
The code distributions of all models follow a standard normal (at least roughly). For 100 epochs, the two moons are usually very clearly visible in the latent distribution. For 1000 epochs, the distribution is a lot closer to a standard normal, but we can still detect a slight gap dividing the two moons. For 10000 epochs, the distributions are visually indistinguishable.
Looking at the synthetic data, we see that the models that converged well indeed seem to be able to accurately reproduce the two moons, while the worse models capture the general shape, but fail to capture the details. By plotting synthetic data of the models sorted by increasing MMD, we see that smaller MMD values indeed correspond to visually better synthetic data, though not exactly in the same order that we would assign to the various distributions (e.g. 4th place should be 2nd).
def GMM(n_samples, labels=None):
    """
    Sample from a 6-mode Gaussian mixture whose means lie on the unit circle.

    Each of the 6 modes contributes n_samples points drawn from a Gaussian
    with covariance 0.01 * I centered at angle n * pi / 3 (mode 0 on the
    right, proceeding counter-clockwise).  The pooled points are shuffled.

    labels: list(int) of len 6, e.g. [0, 0, 1, 2, 1, 2]
        numbers mark the classes of the modes, starting with the right one
        and rotating counter-clockwise.  If given, per-point class labels are
        returned alongside the points, shuffled consistently.

    Returns:
        x of shape (6 * n_samples, 2); additionally y of shape
        (6 * n_samples,) with dtype int64 when labels is not None.
    """
    # Draw all modes once — the labeled and unlabeled paths previously
    # duplicated this code; they now share it (same RNG stream either way).
    x = np.zeros((6, n_samples, 2))
    for n in range(6):
        m = np.array([np.cos(np.pi * n / 3), np.sin(np.pi * n / 3)])
        x[n] = np.random.multivariate_normal(m, np.eye(2) * 0.01, n_samples)
    x = x.reshape(6 * n_samples, 2)
    # A single shared permutation keeps points and labels aligned.
    shuffle_inds = np.arange(6 * n_samples)
    np.random.shuffle(shuffle_inds)
    x = x[shuffle_inds]
    if labels is None:
        return x
    # Label of mode n repeated for its n_samples points, permuted like x.
    y = np.repeat(np.asarray(labels), n_samples)[shuffle_inds].astype(np.int64)
    return x, y
# Visualize one draw from the 6-mode Gaussian mixture (100 samples per mode).
gmm_data = GMM(100)
plt.figure(figsize=(8, 8))
plt.scatter(gmm_data[:,0], gmm_data[:,1])
plt.title("GMM")
plt.xlabel("X")
plt.ylabel("Y")
plt.show()
# --- Same hyperparameter sweep as above, on the 6-mode GMM data ---
hidden_size = 16
n_blocks = 4
MMD_vals = np.zeros((3, 3))
X_test = GMM(1000)
X_test = torch.tensor(X_test).to(torch.float)
X_eval = GMM(1000)
X_eval = torch.tensor(X_eval).to(torch.float)
models = []
for i, N_train in enumerate([100, 1000, 10000]):
    X_train = GMM(N_train)
    X_train = torch.tensor(X_train).to(torch.float)
    for j, n_epochs in enumerate([100, 1000, 10000]):
        print(f"N_train = {N_train}, n_epochs = {n_epochs}")
        model = RealNVP(2, hidden_size, n_blocks)
        train_inn(model, X_train, n_epochs, plot_loss=True, X_eval=X_eval)
        model.eval()
        with torch.no_grad():
            # Latent codes vs. true standard-normal samples.
            plt.figure(figsize=(8, 8))
            gaussians = np.random.multivariate_normal(np.zeros(2), np.eye(2), 1000)
            plt.scatter(model.encode(X_train).detach().numpy()[:,0], model.encode(X_train).detach().numpy()[:,1], label="Encodings", color="C0")
            plt.scatter(gaussians[:,0], gaussians[:,1], label="True Standart Normal", color="C1", marker="x", alpha=0.8)
            plt.title("Latent Space")
            plt.xlabel("$z_1$")
            plt.ylabel("$z_2$")
            plt.legend()
            plt.show()
            # Test data vs. samples; MMD scored against the held-out test set.
            fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
            synths = model.sample(1000).numpy()
            MMD_vals[i, j] = MMD(synths, X_test.numpy())
            axes[0].scatter(X_test[:,0], X_test[:,1])
            axes[0].set_title("Test Data")
            axes[0].set_xlabel("X")
            axes[0].set_ylabel("Y")
            axes[1].scatter(synths[:,0], synths[:,1])
            axes[1].set_title(f"Synthetic Data, MMD = {MMD_vals[i, j]}")
            axes[1].set_xlabel("X")
            plt.tight_layout()
            plt.show()
        print(" ")
        print(" ")
        print(" ")
        models.append(model)
# Show all 9 models' samples ordered by ascending MMD.
MMD_sorted, models_sorted = zip(*sorted(zip(MMD_vals.flatten(), models), key=lambda x: x[0]))
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
for i, model in enumerate(models_sorted):
    synths = model.sample(1000).numpy()
    axes[i//3, i%3].scatter(synths[:,0], synths[:,1])
    axes[i//3, i%3].set_title(f"MMD = {np.round(MMD_sorted[i], 4)}")
    axes[i//3, i%3].set_xlabel("X")
    axes[i//3, i%3].set_ylabel("Y")
plt.tight_layout()
plt.show()
N_train = 100, n_epochs = 100
0%| | 0/100 [00:00<?, ?it/s]
N_train = 100, n_epochs = 1000
0%| | 0/1000 [00:00<?, ?it/s]
N_train = 100, n_epochs = 10000
0%| | 0/10000 [00:00<?, ?it/s]
N_train = 1000, n_epochs = 100
0%| | 0/100 [00:00<?, ?it/s]
N_train = 1000, n_epochs = 1000
0%| | 0/1000 [00:00<?, ?it/s]
N_train = 1000, n_epochs = 10000
0%| | 0/10000 [00:00<?, ?it/s]
N_train = 10000, n_epochs = 100
0%| | 0/100 [00:00<?, ?it/s]
N_train = 10000, n_epochs = 1000
0%| | 0/1000 [00:00<?, ?it/s]
N_train = 10000, n_epochs = 10000
0%| | 0/10000 [00:00<?, ?it/s]
We make similar observations to the two moons case, especially that the best model once again was trained on 10000 training points for 10000 epochs. In this case, the trend seems to be even more clear, that bigger training datasets and longer training time both directly improve performance.
As before, the models trained for only 100 epochs clearly show the 6 modes of the GMM in the latent space, which just seem to be centered around a standard normal. For longer training times, the overlap between the code distribution and the standard normal becomes much larger, but we see less clean results than for the two moons, with the code distributions usually bleeding outside of the standard normal along some directions.
Even though the MMD values are generally larger than for the two moons, the synthetic data in this case seems more convincing with 6 out of 9 models being able to accurately capture the shape of the data distribution. We observe some bleeding between the different modes though, but it seems that the hexagon is easier or at least not harder to learn than the two moons. The MMD values correspond well to visually good synthetic data.
class CondBlock(nn.Module):
    """Conditional RealNVP coupling block: the s/t networks also see a condition vector.

    Identical to Block except that the condition is concatenated onto the
    first half of the input before it enters the scale/shift networks.
    """

    def __init__(self, input_size, hidden_size, condition_size, last):
        """
        Args:
            input_size: dimensionality of the data (assumed even for a clean split).
            hidden_size: width of the hidden layers of the s/t networks.
            condition_size: dimensionality of the conditioning vector.
            last: if True, use the identity instead of a random orthogonal mix.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.condition_size = condition_size
        self.split = self.input_size // 2
        # Final Tanh bounds the log-scale to (-1, 1), keeping exp() well conditioned.
        self.s_net = nn.Sequential(
            nn.Linear(self.split + self.condition_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.split),
            nn.Tanh(),
        )
        self.t_net = nn.Sequential(
            nn.Linear(self.split + self.condition_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.split),
        )
        if last:
            Q = torch.eye(self.input_size)
        else:
            Q, _ = torch.linalg.qr(torch.randn(size=(self.input_size, self.input_size)))
        # Register Q as a buffer (was a plain attribute): it now follows
        # .to(device) and is persisted in state_dict.  It is fixed, not trained.
        self.register_buffer("Q", Q)

    def encode(self, x, condition):
        """Forward pass x -> z under the given condition."""
        z, _ = self.loss_fwd(x, condition)
        return z

    def decode(self, x, condition):
        """Inverse pass z -> x under the given condition; exactly undoes encode."""
        x = x @ self.Q.T
        x1 = x[:, :self.split]
        x2 = x[:, self.split:]
        # Build the conditioning input once instead of once per network.
        h = torch.cat((x1, condition), dim=1)
        s = torch.exp(self.s_net(h))
        t = self.t_net(h)
        x2 = (x2 - t) / s
        return torch.cat((x1, x2), dim=1)

    def loss_fwd(self, x, condition):
        """Forward pass that also returns the log-scales for the NLL.

        Returns:
            (z, s_wig): transformed batch and per-dimension log-scales; the
            row-sum of s_wig is this block's log|det Jacobian| (Q adds nothing).
        """
        x1 = x[:, :self.split]
        x2 = x[:, self.split:]
        h = torch.cat((x1, condition), dim=1)
        s_wig = self.s_net(h)
        s = torch.exp(s_wig)
        t = self.t_net(h)
        x2 = s * x2 + t
        return torch.cat((x1, x2), dim=1) @ self.Q, s_wig
class CondRealNVP(nn.Module):
    """Conditional RealNVP: a stack of CondBlocks whose transforms depend on a condition."""

    def __init__(self, input_size, hidden_size, blocks, condition_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.condition_size = condition_size
        self.blocks = blocks
        # Only the last block uses the identity instead of a random orthogonal mix.
        stages = [
            CondBlock(input_size, hidden_size, condition_size, idx == blocks - 1)
            for idx in range(blocks)
        ]
        self.couplings = nn.Sequential(*stages)
        self.sampler = MultivariateNormal(torch.zeros(input_size), torch.eye(input_size))

    def encode(self, x, condition):
        """Run every coupling block forward, conditioning each on `condition`."""
        z = x
        for stage in self.couplings:
            z = stage.encode(z, condition)
        return z

    def decode(self, x, condition):
        """Invert the blocks in reverse order under the same condition."""
        out = x
        for stage in self.couplings[::-1]:
            out = stage.decode(out, condition)
        return out

    def loss(self, x, condition):
        """Mean conditional negative log-likelihood under the standard-normal prior."""
        log_det = torch.zeros(len(x))
        z = x
        for stage in self.couplings:
            z, s_wig = stage.loss_fwd(z, condition)
            log_det = log_det + s_wig.sum(dim=1)
        return torch.mean((z * z).sum(dim=1) / 2 - log_det)

    def sample(self, n_samples, conditions):
        """Decode n_samples prior draws per condition; returns a flat numpy array."""
        with torch.no_grad():
            out = np.zeros((len(conditions), n_samples, self.input_size))
            for idx, cond in enumerate(conditions):
                # Tile the single condition vector over the whole batch.
                tiled = torch.tensor([cond.tolist()] * n_samples).to(torch.float)
                z = self.sampler.sample((n_samples,))
                out[idx] = self.decode(z, tiled)
            return out.reshape(len(conditions) * n_samples, self.input_size)

    def print(self):
        """Debug helper: list every parameter's name, shape, and trainability."""
        for name, param in self.named_parameters():
            print(f"{name}, {param.shape}, grad: {param.requires_grad}")
def train_cinn(model, X_train, y_train, n_epochs, plot_loss=False, eval_data=None, data_loader=None):
    """Train a conditional INN by minimizing its conditional NLL loss with Adam.

    Args:
        model: conditional flow exposing loss(x, condition) (e.g. CondRealNVP).
        X_train, y_train: full training batch and conditions (used when data_loader is None).
        n_epochs: number of full-batch steps, or passes over data_loader.
        plot_loss: if True, plot the loss curve(s) after training.
        eval_data: optional (X_eval, y_eval) tuple; records an eval loss each epoch.
        data_loader: optional minibatch loader; batches are flattened to
            (-1, 64) and labels one-hot encoded with 10 classes — hard-coded
            for 8x8 digit images.
    """
    optimizer = torch.optim.Adam(model.parameters())
    if data_loader is None:
        if eval_data is not None:
            X_eval, y_eval = eval_data
            eval_losses = []
        losses = []
        for i in tqdm(range(n_epochs)):
            # One full-batch gradient step per epoch.
            loss = model.loss(X_train, y_train)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
            if eval_data is not None:
                # Evaluate without gradients, then restore training mode.
                model.eval()
                with torch.no_grad():
                    eval_losses.append(model.loss(X_eval, y_eval).item())
                model.train()
        if plot_loss:
            plt.figure(figsize=(12, 8))
            if eval_data is not None:
                plt.plot(range(n_epochs), losses, label="Train Loss")
                plt.plot(range(n_epochs), eval_losses, label="Eval Loss")
                plt.title("Loss during Training")
                plt.legend()
            else:
                plt.plot(range(n_epochs), losses)
                plt.title("Train Loss", size=18)
            plt.xlabel("Epoch", size=16)
            plt.ylabel("NLL", size=16)
            plt.show()
    else:
        losses = []
        for i in tqdm(range(n_epochs)):
            # Minibatch branch: accumulate the summed loss over the epoch.
            epoch_loss = 0.0
            for X, y in data_loader:
                one_hot_y = nn.functional.one_hot(y.to(torch.int64), 10)
                loss = model.loss(X.reshape(-1, 64), one_hot_y)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                epoch_loss += loss.item()
            losses.append(epoch_loss)
        if plot_loss:
            plt.figure(figsize=(12, 8))
            plt.plot(range(n_epochs), losses)
            plt.title("Train Loss", size=18)
            plt.xlabel("Epoch", size=16)
            plt.ylabel("NLL", size=16)
            plt.show()
# --- Conditional RealNVP on two moons, conditioned on the one-hot moon label ---
hidden_size = 16
n_blocks = 4
N_train = 1000
n_epochs = 1000
X_test, test_flags = make_moons(1000, shuffle=True, noise=0.1)
X_test = torch.tensor(X_test).to(torch.float)
y_test = nn.functional.one_hot(torch.tensor(test_flags)).to(torch.float)
X_eval, y_eval = make_moons(1000, shuffle=True, noise=0.1)
X_eval = torch.tensor(X_eval).to(torch.float)
y_eval = nn.functional.one_hot(torch.tensor(y_eval)).to(torch.float)
eval_data = (X_eval, y_eval)
X_train, y_train = make_moons(N_train, shuffle=True, noise=0.1)
X_train = torch.tensor(X_train).to(torch.float)
y_train = nn.functional.one_hot(torch.tensor(y_train)).to(torch.float)
model = CondRealNVP(2, hidden_size, n_blocks, 2)
train_cinn(model, X_train, y_train, n_epochs, plot_loss=True, eval_data=eval_data)
model.eval()
# One-hot condition vectors for the two classes.
conditions = torch.stack([torch.tensor([1, 0], dtype=torch.float), torch.tensor([0, 1], dtype=torch.float)])
with torch.no_grad():
    # Latent codes vs. true standard-normal samples.
    plt.figure(figsize=(8, 8))
    gaussians = np.random.multivariate_normal(np.zeros(2), np.eye(2), 1000)
    plt.scatter(model.encode(X_train, y_train).detach().numpy()[:,0], model.encode(X_train, y_train).detach().numpy()[:,1], label="Encodings", color="C0")
    plt.scatter(gaussians[:,0], gaussians[:,1], label="True Standart Normal", color="C1", marker="x", alpha=0.8)
    plt.title("Latent Space")
    plt.xlabel("$z_1$")
    plt.ylabel("$z_2$")
    plt.legend()
    plt.show()
# 1000 samples per condition: the first 1000 rows belong to class 0, the rest to class 1.
synths = model.sample(1000, conditions)
for i, condition in enumerate(conditions):
    if i == 0:
        synths_cond = synths[:1000]
    else:
        synths_cond = synths[1000:]
    # Per-class comparison of test data and conditional samples.
    fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
    MMD_val = MMD(synths_cond, X_test[test_flags == i].numpy())
    axes[0].scatter(X_test[test_flags == i].numpy()[:,0], X_test[test_flags == i].numpy()[:,1])
    axes[0].set_title(f"Test Data, y = {i}")
    axes[0].set_xlabel("X")
    axes[0].set_ylabel("Y")
    axes[1].scatter(synths_cond[:,0], synths_cond[:,1])
    axes[1].set_title(f"Synthetic Data, MMD = {MMD_val}")
    axes[1].set_xlabel("X")
    plt.tight_layout()
    plt.show()
# Marginal check: 500 samples per condition against the full (mixed) test set.
fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
synths = model.sample(500, conditions)
MMD_val = MMD(synths, X_test.numpy())
axes[0].scatter(X_test[:,0], X_test[:,1])
axes[0].set_title("Test Data")
axes[0].set_xlabel("X")
axes[0].set_ylabel("Y")
axes[1].scatter(synths[:,0], synths[:,1])
axes[1].set_title(f"Synthetic Data, MMD = {MMD_val}")
axes[1].set_xlabel("X")
plt.tight_layout()
plt.show()
0%| | 0/1000 [00:00<?, ?it/s]
We observe that the latent space follows a standard normal distribution quite well. Furthermore, the model very clearly distinguishes between the two conditions. The synthetic data seems visually more convincing than for the unconditional model, both in conditional mode and for the marginal distribution, even though the MMD is slightly higher than for the best model from above. Nonetheless, we used a lot less data and training time for this model, and it seems likely that using 10000 datapoints and epochs would even further increase performance.
# --- Conditional RealNVP on the GMM, with modes grouped into 2 classes ---
hidden_size = 16
n_blocks = 4
N_train = 1000
n_epochs = 1000
# Mode-to-class assignment [1, 0, 1, 0, 1, 1]: class 0 covers 2 modes, class 1 covers 4.
X_test, test_flags = GMM(1000, [1, 0, 1, 0, 1, 1])
X_test = torch.tensor(X_test).to(torch.float)
y_test = nn.functional.one_hot(torch.tensor(test_flags)).to(torch.float)
X_eval, y_eval = GMM(1000, [1, 0, 1, 0, 1, 1])
X_eval = torch.tensor(X_eval).to(torch.float)
y_eval = nn.functional.one_hot(torch.tensor(y_eval)).to(torch.float)
eval_data = (X_eval, y_eval)
X_train, y_train = GMM(N_train, [1, 0, 1, 0, 1, 1])
X_train = torch.tensor(X_train).to(torch.float)
y_train = nn.functional.one_hot(torch.tensor(y_train)).to(torch.float)
model = CondRealNVP(2, hidden_size, n_blocks, 2)
train_cinn(model, X_train, y_train, n_epochs, plot_loss=True, eval_data=eval_data)
model.eval()
# One-hot condition vectors for the two classes.
conditions = torch.stack([torch.tensor([1, 0], dtype=torch.float), torch.tensor([0, 1], dtype=torch.float)])
with torch.no_grad():
    # Latent codes vs. true standard-normal samples.
    plt.figure(figsize=(8, 8))
    gaussians = np.random.multivariate_normal(np.zeros(2), np.eye(2), 1000)
    plt.scatter(model.encode(X_train, y_train).detach().numpy()[:,0], model.encode(X_train, y_train).detach().numpy()[:,1], label="Encodings", color="C0")
    plt.scatter(gaussians[:,0], gaussians[:,1], label="True Standart Normal", color="C1", marker="x", alpha=0.8)
    plt.title("Latent Space")
    plt.xlabel("$z_1$")
    plt.ylabel("$z_2$")
    plt.legend()
    plt.show()
# 1000 samples per condition: the first 1000 rows belong to class 0, the rest to class 1.
synths = model.sample(1000, conditions)
for i, condition in enumerate(conditions):
    if i == 0:
        synths_cond = synths[:1000]
    else:
        synths_cond = synths[1000:]
    # Per-class comparison of test data and conditional samples.
    fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
    MMD_val = MMD(synths_cond, X_test[test_flags == i].numpy())
    axes[0].scatter(X_test[test_flags == i].numpy()[:,0], X_test[test_flags == i].numpy()[:,1])
    axes[0].set_title(f"Test Data, y = {i}")
    axes[0].set_xlabel("X")
    axes[0].set_ylabel("Y")
    axes[1].scatter(synths_cond[:,0], synths_cond[:,1])
    axes[1].set_title(f"Synthetic Data, MMD = {MMD_val}")
    axes[1].set_xlabel("X")
    plt.tight_layout()
    plt.show()
# Marginal check: 500 samples per condition against the full (mixed) test set.
fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
synths = model.sample(500, conditions)
MMD_val = MMD(synths, X_test.numpy())
axes[0].scatter(X_test[:,0], X_test[:,1])
axes[0].set_title("Test Data")
axes[0].set_xlabel("X")
axes[0].set_ylabel("Y")
axes[1].scatter(synths[:,0], synths[:,1])
axes[1].set_title(f"Synthetic Data, MMD = {MMD_val}")
axes[1].set_xlabel("X")
plt.tight_layout()
plt.show()
0%| | 0/1000 [00:00<?, ?it/s]
Once again, the model captures the conditional distributions very well, though we observe some bleeding between the modes belonging to the same class. Once again the MMD values are higher than the best model from above, and in this case also visually the marginal distribution does not quite reach the quality of the best model from above. But as for the two moons, we trained the model on 1000 points and epochs only, and it is likely that the model would outperform the unconditional counterpart for 10000 points and epochs.
# --- Unconditional RealNVP on the 8x8 sklearn digits (64-dim data) ---
# hidden_size = 16
# n_blocks = 4
hidden_size = 64
n_blocks = 8
n_epochs = 1000
digits_X, digits_y = load_digits(return_X_y=True)
# 80/10/10 train/eval/test split by index.
X_train = torch.tensor(digits_X[:int(0.8*len(digits_X))]).to(torch.float)
X_eval = torch.tensor(digits_X[int(0.8*len(digits_X)):int(0.9*len(digits_X))]).to(torch.float)
X_test = torch.tensor(digits_X[int(0.9*len(digits_X)):]).to(torch.float)
y_train = torch.tensor(digits_y[:int(0.8*len(digits_y))]).to(torch.float)
y_eval = torch.tensor(digits_y[int(0.8*len(digits_y)):int(0.9*len(digits_y))]).to(torch.float)
y_test = torch.tensor(digits_y[int(0.9*len(digits_y)):]).to(torch.float)
model = RealNVP(64, hidden_size, n_blocks)
train_inn(model, X_train, n_epochs, plot_loss=True, X_eval=X_eval)
model.eval()
with torch.no_grad():
    # 16 random 2D projections of the 64-dim latent space vs. standard-normal draws.
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    for i in range(4):
        for j in range(4):
            z1_ind, z2_ind = np.random.choice(64, 2, replace=False)
            gaussians = np.random.multivariate_normal(np.zeros(2), np.eye(2), 1000)
            axes[i, j].scatter(model.encode(X_train).detach().numpy()[:,z1_ind], model.encode(X_train).detach().numpy()[:,z2_ind], label="Encodings", color="C0")
            axes[i, j].scatter(gaussians[:,0], gaussians[:,1], label="True Standart Normal", color="C1", marker="x", alpha=0.8)
            axes[i, j].set_title(f"Latent Space: $z_{{{z1_ind}}} - z_{{{z2_ind}}}$")
            axes[i, j].set_xlabel(f"$z_{{{z1_ind}}}$")
            axes[i, j].set_ylabel(f"$z_{{{z2_ind}}}$")
    plt.tight_layout()
    plt.show()
    # First column: real test digits; remaining columns: generated samples.
    fig, axes = plt.subplots(4, 4, figsize=(12, 12), sharex=True, sharey=True)
    synths = model.sample(1000).numpy()
    MMD_val = MMD(synths, X_test.numpy())
    for i in range(4):
        for j in range(4):
            if j == 0:
                ind = np.random.choice(len(X_test))
                axes[i, j].imshow(X_test[ind].reshape(8, 8), cmap="gray")
                axes[i, j].set_xticks([])
                axes[i, j].set_yticks([])
            else:
                ind = np.random.choice(len(synths))
                axes[i, j].imshow(synths[ind].reshape(8, 8), cmap="gray")
                axes[i, j].set_xticks([])
                axes[i, j].set_yticks([])
    axes[0, 1].set_title(f"Synthetic Data")
    axes[0, 2].set_title(f"Synthetic Data")
    axes[0, 3].set_title(f"Synthetic Data")
    axes[0, 0].set_title("Test Data")
    plt.tight_layout()
    plt.show()
    print(f"MMD = {MMD_val}")
0%| | 0/1000 [00:00<?, ?it/s]
MMD = 1.3231708300932964e-17
We plot 16 randomly selected 2D projections of the code distribution, and see that all 16 of them follow a standard normal very closely. The generated data is not very good though. It is possible to identify around half of the generated images as a number, while the other half could at least have 2 different labels.
class BottleneckRealNVP(nn.Module):
    """RealNVP variant with an information bottleneck on the first k latent dimensions.

    Trained with the usual flow NLL plus an L1 reconstruction penalty that
    decodes only the first k latent dimensions (the rest zeroed), pushing the
    model to concentrate the data information into those k dimensions.
    """
    def __init__(self, input_size, hidden_size, blocks, k):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.blocks = blocks
        # Number of "important" (bottleneck) latent dimensions.
        self.k = k
        self.couplings = nn.Sequential(*[
            Block(self.input_size, self.hidden_size, i == self.blocks - 1)
            for i in range(self.blocks)
        ])
        self.sampler = MultivariateNormal(torch.zeros(self.input_size), torch.eye(self.input_size))
    def encode(self, x):
        # Run all coupling blocks forward.
        for b in self.couplings:
            x = b.encode(x)
        return x
    def decode(self, x):
        # Invert the blocks in reverse order.
        for b in self.couplings[::-1]:
            x = b.decode(x)
        return x
    def loss(self, x):
        """Flow NLL of the full code plus L1 reconstruction error through the bottleneck."""
        s_vals = torch.zeros(len(x))
        # Accumulate log-scales while pushing the batch through all blocks.
        x_hat, s_wig = self.couplings[0].loss_fwd(x)
        s_vals += torch.sum(s_wig, dim=1)
        for b in self.couplings[1:]:
            x_hat, s_wig = b.loss_fwd(x_hat)
            s_vals += torch.sum(s_wig, dim=1)
        f_x_sq = torch.sum(x_hat * x_hat, dim=1) / 2
        # Keep only the first k latent dimensions, zero out the rest ...
        codes = torch.zeros_like(x_hat)
        codes[:, :self.k] += x_hat[:, :self.k]
        # ... and decode them back to data space for the reconstruction term.
        recs = self.couplings[-1].decode(codes)
        for b in self.couplings[::-1][1:]:
            recs = b.decode(recs)
        return torch.mean(f_x_sq - s_vals) + torch.mean(torch.abs(recs - x))
    def sample_k(self, n_samples):
        """Sample using only the bottleneck dimensions (the rest of the code zeroed)."""
        with torch.no_grad():
            codes = self.sampler.sample((n_samples,))
            codes[:, self.k:] = 0
            generated = self.decode(codes)
            return generated
    def sample_rest(self, n_samples):
        """Sample using only the non-bottleneck dimensions (bottleneck zeroed)."""
        with torch.no_grad():
            codes = self.sampler.sample((n_samples,))
            codes[:,:self.k] = 0
            generated = self.decode(codes)
            return generated
    def print(self):
        # Debug helper: list parameter names, shapes, and trainability.
        for name, param in self.named_parameters():
            print(f"{name}, {param.shape}, grad: {param.requires_grad}")
# --- BottleneckRealNVP on the digits dataset (k = 8 bottleneck dims) ---
# hidden_size = 16
# n_blocks = 4
hidden_size = 64
n_blocks = 8
k = 8
n_epochs = 5000
digits_X, digits_y = load_digits(return_X_y=True)
# 80/10/10 train/eval/test split by index.
X_train = torch.tensor(digits_X[:int(0.8*len(digits_X))]).to(torch.float)
X_eval = torch.tensor(digits_X[int(0.8*len(digits_X)):int(0.9*len(digits_X))]).to(torch.float)
X_test = torch.tensor(digits_X[int(0.9*len(digits_X)):]).to(torch.float)
y_train = torch.tensor(digits_y[:int(0.8*len(digits_y))]).to(torch.float)
y_eval = torch.tensor(digits_y[int(0.8*len(digits_y)):int(0.9*len(digits_y))]).to(torch.float)
y_test = torch.tensor(digits_y[int(0.9*len(digits_y)):]).to(torch.float)
model = BottleneckRealNVP(64, hidden_size, n_blocks, k)
train_inn(model, X_train, n_epochs, plot_loss=True, X_eval=X_eval)
model.eval()
with torch.no_grad():
    # Column 0: projections within the bottleneck dims; columns 1-3: within the rest.
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    for i in range(4):
        for j in range(4):
            if j == 0:
                z1_ind, z2_ind = np.random.choice(k, 2, replace=False)
                gaussians = np.random.multivariate_normal(np.zeros(2), np.eye(2), 1000)
                axes[i, j].scatter(model.encode(X_train).detach().numpy()[:,z1_ind], model.encode(X_train).detach().numpy()[:,z2_ind], label="Encodings", color="C0")
                axes[i, j].scatter(gaussians[:,0], gaussians[:,1], label="True Standart Normal", color="C1", marker="x", alpha=0.8)
                axes[i, j].set_title(f"Latent Space: $z_{{{z1_ind}}} - z_{{{z2_ind}}}$")
                axes[i, j].set_xlabel(f"$z_{{{z1_ind}}}$")
                axes[i, j].set_ylabel(f"$z_{{{z2_ind}}}$")
            else:
                z1_ind, z2_ind = np.random.choice(range(k, 64), 2, replace=False)
                gaussians = np.random.multivariate_normal(np.zeros(2), np.eye(2), 1000)
                axes[i, j].scatter(model.encode(X_train).detach().numpy()[:,z1_ind], model.encode(X_train).detach().numpy()[:,z2_ind], label="Encodings", color="C0")
                axes[i, j].scatter(gaussians[:,0], gaussians[:,1], label="True Standart Normal", color="C1", marker="x", alpha=0.8)
                axes[i, j].set_title(f"Latent Space: $z_{{{z1_ind}}} - z_{{{z2_ind}}}$")
                axes[i, j].set_xlabel(f"$z_{{{z1_ind}}}$")
                axes[i, j].set_ylabel(f"$z_{{{z2_ind}}}$")
    plt.tight_layout()
    plt.show()
    # Samples from the bottleneck only (non-bottleneck code dims zeroed).
    print("Generated with unimportant dimensions set to 0")
    fig, axes = plt.subplots(4, 4, figsize=(12, 12), sharex=True, sharey=True)
    synths = model.sample_k(1000).numpy()
    MMD_val = MMD(synths, X_test.numpy())
    for i in range(4):
        for j in range(4):
            if j == 0:
                ind = np.random.choice(len(X_test))
                axes[i, j].imshow(X_test[ind].reshape(8, 8), cmap="gray")
                axes[i, j].set_xticks([])
                axes[i, j].set_yticks([])
            else:
                ind = np.random.choice(len(synths))
                axes[i, j].imshow(synths[ind].reshape(8, 8), cmap="gray")
                axes[i, j].set_xticks([])
                axes[i, j].set_yticks([])
    axes[0, 1].set_title(f"Synthetic Data")
    axes[0, 2].set_title(f"Synthetic Data")
    axes[0, 3].set_title(f"Synthetic Data")
    axes[0, 0].set_title("Test Data")
    plt.tight_layout()
    plt.show()
    # Samples from the non-bottleneck dims only (bottleneck zeroed).
    # NOTE(review): MMD_val is overwritten here, so the final print reports
    # only this second sampling mode.
    print("Generated with bottleneck dimensions set to 0")
    fig, axes = plt.subplots(4, 4, figsize=(12, 12), sharex=True, sharey=True)
    synths = model.sample_rest(1000).numpy()
    MMD_val = MMD(synths, X_test.numpy())
    for i in range(4):
        for j in range(4):
            if j == 0:
                ind = np.random.choice(len(X_test))
                axes[i, j].imshow(X_test[ind].reshape(8, 8), cmap="gray")
                axes[i, j].set_xticks([])
                axes[i, j].set_yticks([])
            else:
                ind = np.random.choice(len(synths))
                axes[i, j].imshow(synths[ind].reshape(8, 8), cmap="gray")
                axes[i, j].set_xticks([])
                axes[i, j].set_yticks([])
    axes[0, 1].set_title(f"Synthetic Data")
    axes[0, 2].set_title(f"Synthetic Data")
    axes[0, 3].set_title(f"Synthetic Data")
    axes[0, 0].set_title("Test Data")
    plt.tight_layout()
    plt.show()
    print(f"MMD = {MMD_val}")
0%| | 0/5000 [00:00<?, ?it/s]
Generated with unimportant dimensions set to 0
Generated with bottleneck dimensions set to 0
MMD = 1.3231708300932964e-17
Once again, we plot the code distribution (the 4 panels on the left are the bottleneck, the other 12 are the "unimportant" dimensions), and observe that all of them follow the standard normal well.
The generated data is still not very convincing, but we can slightly see the effects of the two different sampling methods.
With the "unimportant" dimensions set to 0, the model seems to sample more or less a mean image; the numbers are usually slimmer and some of them can be identified as 9, 0, 5 and 1.
With the bottleneck set to 0, the images are a lot more diverse or random.
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image, ImageFilter
# define a functor to downsample images
class DownsampleTransform():
    """Resize a PIL image to a fixed target size.

    The image is first resampled to (width + 2, height + 2) and then
    cropped by one pixel on every side, discarding potential boundary
    artifacts introduced by the resampling filter.
    """

    def __init__(self, target_shape, algorithm=Image.Resampling.LANCZOS):
        # target_shape is (width, height); algorithm is a PIL resampling filter
        self.width, self.height = target_shape
        self.algorithm = algorithm

    def __call__(self, img):
        # resample slightly larger than requested, then trim the 1-px border
        enlarged = img.resize((self.width + 2, self.height + 2), self.algorithm)
        return enlarged.crop((1, 1, self.width + 1, self.height + 1))
# concatenate a few transforms: downsample to 8x8, force a single grayscale
# channel, then convert to a float tensor in [0, 1]
transform = transforms.Compose([
    DownsampleTransform(target_shape=(8, 8)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor()
])
# download MNIST (training split); the transform is applied on the fly
mnist_dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True
)
# create a DataLoader that serves shuffled minibatches of size 100
data_loader = DataLoader(mnist_dataset, batch_size=100, shuffle=True)
# visualize the first batch of downsampled MNIST images
def show_first_batch(data_loader):
for batch in data_loader:
x, y = batch
fig = plt.figure(figsize=(10, 10))
for i, img in enumerate(x):
ax = fig.add_subplot(10, 10, i+1)
ax.imshow(img.reshape(8, 8), cmap="gray")
ax.axis("off")
break
show_first_batch(data_loader)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:03<00:00, 2730914.94it/s]
Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<?, ?it/s]
Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 8867345.96it/s]
Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 915978.88it/s]
Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw
# hidden_size = 16
# n_blocks = 4
hidden_size = 64   # width of the coupling-net hidden layers
n_blocks = 8       # number of coupling blocks in the flow
k = 8              # bottleneck size: number of "important" latent dimensions
# n_epochs = 5000
n_epochs = 100
# 64 = flattened 8x8 image; train on the MNIST DataLoader built above
model = BottleneckRealNVP(64, hidden_size, n_blocks, k)
train_inn(model, None, n_epochs, plot_loss=True, data_loader=data_loader)
model.eval()
with torch.no_grad():
    # --- latent-space diagnostics -----------------------------------------
    # Collect 10 batches (1000 images) and encode them ONCE before plotting.
    # Bug fix: the original indexed with the plot-row counter `i` instead of
    # the batch counter `l`, so only one of the 10 slots was ever filled and
    # most "encodings" were of all-zero images; it also re-collected the data
    # for every grid row.
    X_to_encode = torch.zeros(10, 100, 64)
    for l, (x, y) in enumerate(data_loader):
        X_to_encode[l] = x.reshape(100, 64)
        if l >= 9:
            break
    X_to_encode = X_to_encode.reshape(1000, 64)
    encodings = model.encode(X_to_encode).detach().numpy()

    # 4x4 grid of 2D latent scatter plots against a reference standard normal
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    for i in range(4):
        for j in range(4):
            if j == 0:
                # first column: random pair of bottleneck dimensions
                z1_ind, z2_ind = np.random.choice(k, 2, replace=False)
            else:
                # other columns: random pair of "unimportant" dimensions
                z1_ind, z2_ind = np.random.choice(range(k, 64), 2, replace=False)
            gaussians = np.random.multivariate_normal(np.zeros(2), np.eye(2), 1000)
            axes[i, j].scatter(encodings[:, z1_ind], encodings[:, z2_ind], label="Encodings", color="C0")
            axes[i, j].scatter(gaussians[:, 0], gaussians[:, 1], label="True Standard Normal", color="C1", marker="x", alpha=0.8)
            axes[i, j].set_title(f"Latent Space: $z_{{{z1_ind}}} - z_{{{z2_ind}}}$")
            axes[i, j].set_xlabel(f"$z_{{{z1_ind}}}$")
            axes[i, j].set_ylabel(f"$z_{{{z2_ind}}}$")
    plt.tight_layout()
    plt.show()

    def _show_sample_grid(synths):
        """4x4 grid: first column shows random real images, the rest synthetic."""
        fig, axes = plt.subplots(4, 4, figsize=(12, 12), sharex=True, sharey=True)
        for i in range(4):
            for j in range(4):
                if j == 0:
                    # draw one random real image from the first batch
                    for x, y in data_loader:
                        ind = np.random.choice(100)
                        axes[i, j].imshow(x[ind].reshape(8, 8), cmap="gray")
                        break
                else:
                    ind = np.random.choice(len(synths))
                    axes[i, j].imshow(synths[ind].reshape(8, 8), cmap="gray")
                axes[i, j].set_xticks([])
                axes[i, j].set_yticks([])
        axes[0, 0].set_title("Test Data")
        for j in range(1, 4):
            axes[0, j].set_title("Synthetic Data")
        plt.tight_layout()
        plt.show()

    # samples drawn with the "unimportant" latent dimensions zeroed
    print("Generated with unimportant dimensions set to 0")
    _show_sample_grid(model.sample_k(1000).numpy())

    # samples drawn with the bottleneck latent dimensions zeroed
    print("Generated with bottleneck dimensions set to 0")
    _show_sample_grid(model.sample_rest(1000).numpy())
0%| | 0/100 [00:00<?, ?it/s]
Generated with unimportant dimensions set to 0
Generated with bottleneck dimensions set to 0
Using a bigger training set greatly improved the quality of the generated data, but many numbers are still not very readable. Usually, the digit groups {9, 5, 3} and {6, 0} are hard to distinguish from one another, while digits like 4 or 7 are almost entirely absent. For the sampling method where the bottleneck is set to 0, the results look very similar.
class CondBottleneckRealNVP(nn.Module):
    """Conditional RealNVP flow with an information bottleneck.

    A stack of conditional affine coupling blocks maps `input_size`-dim data
    to a latent code of the same dimensionality; only the first `k` latent
    dimensions are treated as "important" (the bottleneck). The condition
    vector (e.g. a one-hot class label of length `condition_size`) is fed to
    every coupling block.
    """

    def __init__(self, input_size, hidden_size, blocks, condition_size, k):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.blocks = blocks
        self.condition_size = condition_size
        self.k = k
        # the final block is flagged so it uses an identity mixing matrix
        # (see CondBlock); earlier blocks use a random orthogonal one
        self.couplings = nn.Sequential(*[
            CondBlock(self.input_size, self.hidden_size, self.condition_size, i == self.blocks - 1)
            for i in range(self.blocks)
        ])
        # standard-normal latent prior used for sampling
        self.sampler = MultivariateNormal(torch.zeros(self.input_size), torch.eye(self.input_size))

    def encode(self, x, condition):
        """Map data `x` to latent codes, conditioning every coupling block."""
        for b in self.couplings:
            x = b.encode(x, condition)
        return x

    def decode(self, x, condition):
        """Invert `encode`: run the blocks in reverse to map codes back to data."""
        for b in self.couplings[::-1]:
            x = b.decode(x, condition)
        return x

    def loss(self, x, condition):
        """Negative log-likelihood plus an L1 bottleneck reconstruction term.

        The NLL part is mean(||f(x)||^2 / 2 - log|det J|), where the log-det
        is accumulated from the per-block scale outputs `s_wig`. The second
        term reconstructs x from the bottleneck only (non-bottleneck latent
        dims zeroed) and penalizes the mean absolute reconstruction error.
        """
        s_vals = torch.zeros(len(x))
        # forward pass through all blocks, accumulating log-det-Jacobian terms
        x_hat, s_wig = self.couplings[0].loss_fwd(x, condition)
        s_vals += torch.sum(s_wig, dim=1)
        for b in self.couplings[1:]:
            x_hat, s_wig = b.loss_fwd(x_hat, condition)
            s_vals += torch.sum(s_wig, dim=1)
        # Gaussian prior term ||f(x)||^2 / 2
        f_x_sq = torch.sum(x_hat * x_hat, dim=1) / 2
        # keep only the first k ("bottleneck") latent dimensions
        codes = torch.zeros_like(x_hat)
        codes[:, :self.k] += x_hat[:, :self.k]
        # decode the truncated codes back to data space (blocks in reverse)
        recs = self.couplings[-1].decode(codes, condition)
        for b in self.couplings[::-1][1:]:
            recs = b.decode(recs, condition)
        return torch.mean(f_x_sq - s_vals) + torch.mean(torch.abs(recs - x))

    def sample_k(self, n_samples, conditions):
        """Draw `n_samples` per condition with non-bottleneck dims set to 0.

        Returns a NumPy array of shape (len(conditions) * n_samples, input_size).
        """
        with torch.no_grad():
            generated = np.zeros((len(conditions), n_samples, self.input_size))
            for i, condition in enumerate(conditions):
                # repeat the condition vector once per sample
                cond_stacked = torch.tensor([condition.tolist()] * n_samples).to(torch.float)
                codes = self.sampler.sample((n_samples,))
                codes[:, self.k:] = 0  # zero the "unimportant" dimensions
                generated[i] = self.decode(codes, cond_stacked)
            return generated.reshape(len(conditions) * n_samples, self.input_size)

    def sample_rest(self, n_samples, conditions):
        """Draw `n_samples` per condition with the bottleneck dims set to 0.

        Returns a NumPy array of shape (len(conditions) * n_samples, input_size).
        """
        with torch.no_grad():
            generated = np.zeros((len(conditions), n_samples, self.input_size))
            for i, condition in enumerate(conditions):
                cond_stacked = torch.tensor([condition.tolist()] * n_samples).to(torch.float)
                codes = self.sampler.sample((n_samples,))
                codes[:, :self.k] = 0  # zero the bottleneck dimensions
                generated[i] = self.decode(codes, cond_stacked)
            return generated.reshape(len(conditions) * n_samples, self.input_size)

    def print(self):
        """Print every registered parameter: name, shape, and requires_grad."""
        for name, param in self.named_parameters():
            print(f"{name}, {param.shape}, grad: {param.requires_grad}")
# hidden_size = 16
# n_blocks = 4
hidden_size = 64   # width of the coupling-net hidden layers
n_blocks = 8       # number of conditional coupling blocks
k = 8              # bottleneck size
# n_epochs = 5000
n_epochs = 100
# 64 = flattened 8x8 image, 10 = size of the one-hot class condition
model = CondBottleneckRealNVP(64, hidden_size, n_blocks, 10, k)
train_cinn(model, None, None, n_epochs, plot_loss=True, data_loader=data_loader)
# one-hot condition vectors for the ten digit classes 0..9
conditions = nn.functional.one_hot(torch.arange(10).to(torch.int64))
model.eval()
with torch.no_grad():
    # --- latent-space diagnostics -----------------------------------------
    # Collect 10 batches (1000 images + one-hot labels) and encode them ONCE.
    # Bug fix: the original indexed with the plot-row counter `i` instead of
    # the batch counter `l`, so only one of the 10 slots was ever filled and
    # most "encodings" were of all-zero images; it also re-collected the data
    # for every grid row.
    X_to_encode = torch.zeros(10, 100, 64)
    y_to_encode = torch.zeros(10, 100, 10)
    for l, (x, y) in enumerate(data_loader):
        X_to_encode[l] = x.reshape(100, 64)
        y_to_encode[l] = nn.functional.one_hot(y.to(torch.int64), 10)
        if l >= 9:
            break
    X_to_encode = X_to_encode.reshape(1000, 64)
    y_to_encode = y_to_encode.reshape(1000, 10)
    encodings = model.encode(X_to_encode, y_to_encode).detach().numpy()

    # 4x4 grid of 2D latent scatter plots against a reference standard normal
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    for i in range(4):
        for j in range(4):
            if j == 0:
                # first column: random pair of bottleneck dimensions
                z1_ind, z2_ind = np.random.choice(k, 2, replace=False)
            else:
                # other columns: random pair of "unimportant" dimensions
                z1_ind, z2_ind = np.random.choice(range(k, 64), 2, replace=False)
            gaussians = np.random.multivariate_normal(np.zeros(2), np.eye(2), 1000)
            axes[i, j].scatter(encodings[:, z1_ind], encodings[:, z2_ind], label="Encodings", color="C0")
            axes[i, j].scatter(gaussians[:, 0], gaussians[:, 1], label="True Standard Normal", color="C1", marker="x", alpha=0.8)
            axes[i, j].set_title(f"Latent Space: $z_{{{z1_ind}}} - z_{{{z2_ind}}}$")
            axes[i, j].set_xlabel(f"$z_{{{z1_ind}}}$")
            axes[i, j].set_ylabel(f"$z_{{{z2_ind}}}$")
    plt.tight_layout()
    plt.show()

    def _show_conditional_grid(synths):
        """10x10 grid of samples; row r was generated conditioned on digit r."""
        fig = plt.figure(figsize=(10, 10))
        for idx, img in enumerate(synths):
            ax = fig.add_subplot(10, 10, idx + 1)
            ax.imshow(img.reshape(8, 8), cmap="gray")
            ax.axis("off")
            # label each cell with the digit class it was conditioned on
            ax.set_title(f"{idx // 10}")
        plt.tight_layout()
        plt.show()

    # samples drawn with the "unimportant" latent dimensions zeroed
    print("Generated with unimportant dimensions set to 0")
    _show_conditional_grid(model.sample_k(10, conditions))

    # samples drawn with the bottleneck latent dimensions zeroed
    print("Generated with bottleneck dimensions set to 0")
    _show_conditional_grid(model.sample_rest(10, conditions))
0%| | 0/100 [00:00<?, ?it/s]
Generated with unimportant dimensions set to 0
Generated with bottleneck dimensions set to 0
The conditional model worked a lot better. When setting the "unimportant" dimensions to 0, the model produces very convincing and easy-to-read results that follow the given condition. Among the 100 plotted synthetic numbers, we found only 1 image that we would not be able to identify correctly (the second "3").
When setting the bottleneck to 0, the results indeed vary a lot more. In the training set, the "0" seem to be mostly tilted towards the right, but in this mode the model is capable of sampling upright "0" and also some tilted to the left. For "1" we also observe a broader range of angles. For "2" we notice variations in the sizes of the different components (overall small, long diagonal, big or no loop in the lower left). Similar observations can be made for every digit. The more creative way of sampling comes at the cost of more numbers being unidentifiable.